#!/usr/bin/env python3
#
# Copyright (c) Siemens AG, 2020
#
# Authors:
#  Chao Zeng <chao.zeng@siemens.com>
#  Jan Kiszka <jan.kiszka@siemens.com>
#
# This file is subject to the terms and conditions of the MIT License.  See
# COPYING.MIT file in the top-level directory.
#

import sys
import os
import fcntl
import struct
import mmap
import hashlib
import argparse

# Set to True by main() once the user has confirmed --force; while it is set,
# FirmwareUpdateException() returns without printing or exiting.
force_update = False


def FirmwareUpdateException(str):
    """Print the given error text and terminate with exit status 1.

    Does nothing when the module-level ``force_update`` flag is set, so that
    a forced update carries on past the failed check.
    """
    if force_update:
        return

    print(str)
    sys.exit(1)


class FirmwareUpdate(object):
    """Verify a firmware image and write it to the OSPI flash MTD partitions.

    The image is passed in as an already opened binary file object. Its
    ``name`` attribute is used to locate the accompanying checksum file.
    """

    def __init__(self, firmware):
        """Remember the firmware image to work on.

        @firmware: firmware image, opened as a binary file object
        """
        self.firmware = firmware

    def sha256_check(self):
        """Compare the image's SHA-256 digest with '<image name>.sha256'.

        The first whitespace-separated token of the checksum file is taken
        as the expected hex digest. A missing or empty checksum file and a
        digest mismatch are reported through FirmwareUpdateException().
        A failure to read the image itself always ends the program with
        exit status 1.
        """
        check_sha256_file = self.firmware.name + '.sha256'
        file_sha256 = None
        try:
            with open(check_sha256_file, 'r') as f:
                tokens = f.read().split()
        except FileNotFoundError:
            FirmwareUpdateException("sha256 file not exists")
        else:
            if tokens:
                file_sha256 = tokens[0]
            else:
                # An empty checksum file has no first token; report it
                # instead of failing with an IndexError.
                FirmwareUpdateException("sha256 file is empty")

        try:
            self.firmware.seek(0)
            method = hashlib.sha256()
            # Hash in chunks so the whole image need not be held in memory.
            for chunk in iter(lambda: self.firmware.read(1 << 16), b''):
                method.update(chunk)
            hash_code = method.hexdigest()
        except IOError as e:
            print("Reading {} failed: {}".format(self.firmware.name, e.strerror))
            sys.exit(1)

        if file_sha256 != hash_code:
            FirmwareUpdateException("Image checksum is not correct")

    def get_path_type_value(self, path):
        """Return the complete text content of the file at @path.

        Prints an error and ends the program with exit status 1 when the
        file cannot be read.
        @path: path of the file to read
        """
        try:
            with open(path, "r") as f:
                return f.read()
        except IOError as e:
            print("Reading {} failed: {}".format(path, e.strerror))
            sys.exit(1)

    @staticmethod
    def flash_erase(dev, start, nbytes):
        """This function erases flash sectors

        Prints an error and ends the program with exit status 1 when the
        ioctl fails.
        @dev: flash device file descriptor
        @start: start address
        @nbytes: number of bytes to erase
        """
        MEMERASE = 0x40084d02

        # The request payload is two native unsigned ints: start and length.
        ioctl_data = struct.pack('II', start, nbytes)

        try:
            fcntl.ioctl(dev, MEMERASE, ioctl_data)
        except IOError:
            print("ioctl failed")
            sys.exit(1)

    def cpu_id_check(self):
        """Check the device ID register against the supported device IDs.

        Reads a 32-bit register through /dev/mem, shifts it right by 11 bits
        and reports an unsupported value through FirmwareUpdateException().
        Ends the program with exit status 1 when /dev/mem cannot be opened.
        """
        cpu_device_id = [0x142ba, 0x142fa, 0x140ff]
        cpuid_register_addr = 0x43000018
        # The mapping has to start on a page boundary: map the page that
        # contains the register and read at the register's offset within it.
        base_addr = cpuid_register_addr & ~(mmap.PAGESIZE - 1)
        base_addr_offset = cpuid_register_addr - base_addr
        try:
            mem_fd = os.open('/dev/mem', os.O_RDWR | os.O_SYNC)
        except OSError:
            # Covers a missing node as well as e.g. insufficient permissions.
            print("Open /dev/mem Failed")
            sys.exit(1)

        try:
            with mmap.mmap(mem_fd, mmap.PAGESIZE, mmap.MAP_SHARED,
                           mmap.PROT_READ, offset=base_addr) as mem:
                mem.seek(base_addr_offset)
                register_value = struct.unpack('I', mem.read(4))[0]
        finally:
            # Release the descriptor whether or not mapping/reading worked.
            os.close(mem_fd)

        current_id = register_value >> 11
        if current_id not in cpu_device_id:
            FirmwareUpdateException("Upgrade is not supported for the board")

    def update_firmware(self):
        """Write the firmware image to the OSPI flash, partition by partition.

        The image is consumed sequentially: starting with mtd0, each MTD
        partition is compared with the image in erase-block sized pieces and
        only blocks that differ are erased and rewritten ("U" is printed for
        a rewritten block, "." for an unchanged one). Once the whole image
        size has been covered, the program exits with status 0.
        """
        mtd_num = 0

        print("===================================================")
        print("IOT2050 firmware update started - DO NOT INTERRUPT!")
        print("===================================================")

        ospi_dev_path = "/sys/bus/platform/devices/47040000.spi"
        if os.path.exists(ospi_dev_path + "/spi_master"):
            # kernel 5.9 and later
            spi_dev = os.listdir(ospi_dev_path + "/spi_master")[0]
            mtd_base_path = "{}/spi_master/{}/{}.0/mtd".format(ospi_dev_path, spi_dev, spi_dev)
        else:
            # kernel 5.8 and earlier
            mtd_base_path = "{}/mtd".format(ospi_dev_path)

        self.firmware.seek(0)
        firmware_size = os.stat(self.firmware.fileno()).st_size

        while True:
            # firmware_size counts down the image bytes still to be covered.
            if firmware_size <= 0:
                print("\nCompleted. Please reboot the device\n")
                sys.exit(0)

            mtd_sys_path = "{}/mtd{}".format(mtd_base_path, mtd_num)
            mtd_name_path = "{}/name".format(mtd_sys_path)
            mtd_size_path = "{}/size".format(mtd_sys_path)
            mtd_erasesize_path = "{}/erasesize".format(mtd_sys_path)
            mtd_dev_path = "/dev/mtd{}".format(mtd_num)

            mtd_size = int(self.get_path_type_value(mtd_size_path))
            mtd_erasesize = int(self.get_path_type_value(mtd_erasesize_path))
            mtd_name = self.get_path_type_value(mtd_name_path).strip()
            mtd_pos = 0

            print("Updating %-20s" % mtd_name, end="")

            try:
                mtd_dev = os.open(mtd_dev_path, os.O_SYNC | os.O_RDWR)
            except IOError as e:
                print("Opening {} failed: {}".format(mtd_dev_path, e.strerror))
                sys.exit(1)

            try:
                while mtd_pos < mtd_size:
                    mtd_content = os.read(mtd_dev, mtd_erasesize)
                    firmware_content = self.firmware.read(mtd_erasesize)

                    if mtd_content != firmware_content:
                        print("U", end="")
                        sys.stdout.flush()
                        self.flash_erase(mtd_dev, mtd_pos, mtd_erasesize)
                        # The read above moved the file position past this
                        # block; go back before writing the new content.
                        os.lseek(mtd_dev, mtd_pos, os.SEEK_SET)
                        os.write(mtd_dev, firmware_content)
                    else:
                        print(".", end="")
                        sys.stdout.flush()
                    mtd_pos += mtd_erasesize
                    firmware_size -= mtd_erasesize
                print()
            finally:
                # Also close the device when the loop is left by an
                # exception (including the SystemExit from flash_erase).
                os.close(mtd_dev)

            mtd_num += 1


def main(argv):
    """Parse the command line, ask for confirmation and run the update.

    The user has to confirm the reset of the U-Boot environment and, when
    --force is given, the forced update as well; any answer other than "y"
    ends the program with exit status 1. After that the image checksum and
    the device ID are checked and the firmware is written.

    @argv: argument vector including the program name (as in sys.argv)
    """
    global force_update

    parser = argparse.ArgumentParser(description='Update OSPI firmware.')
    parser.add_argument('firmware', metavar='FIRMWARE',
                        type=argparse.FileType('rb'),
                        help='firmware image')
    parser.add_argument('-f', '--force',
                        help='Force update, ignoring image checksum or device ID mismatches',
                        action='store_true')

    try:
        # Parse the vector that was passed in instead of implicitly falling
        # back to sys.argv; argv[0] is the program name and is skipped.
        args = parser.parse_args(argv[1:])
    except IOError as e:
        print(e.strerror, file=sys.stderr)
        sys.exit(1)

    erase_env_input = input("\nWarning: All U-Boot environment variables will be reset to factory settings. Continue (y/N)? ")
    if erase_env_input != "y":
        sys.exit(1)

    if args.force:
        force_update_input = input("\nWarning: Enforced update may render device unbootable. Continue (y/N)? ")
        if force_update_input != "y":
            sys.exit(1)
        # From here on failed checks no longer stop the update.
        force_update = args.force

    firmupdate = FirmwareUpdate(args.firmware)

    firmupdate.sha256_check()

    firmupdate.cpu_id_check()

    firmupdate.update_firmware()


# Run the updater only when executed as a script, passing the full process
# argument vector to main().
if __name__ == '__main__':
    main(sys.argv)
